{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the `lightly` package, which can be installed with:\n",
    "\n",
    "```\n",
    "pip install lightly\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic instead of !pip so that the package is installed into the\n",
    "# environment of the running kernel rather than whichever pip is first on PATH.\n",
    "%pip install lightly"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the paper. The settings are chosen such that the example can easily be\n",
    "run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.loss import SwaVLoss\n",
    "from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes\n",
    "from lightly.models.modules.memory_bank import MemoryBankModule\n",
    "from lightly.transforms.swav_transform import SwaVTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SwaV(pl.LightningModule):\n",
    "    \"\"\"SwaV model: ResNet-18 backbone, projection head, prototypes and feature queues.\n",
    "\n",
    "    Args:\n",
    "        n_queues: Number of feature queues. One queue is kept per high-resolution\n",
    "            view, so the first `n_queues` views of a batch are treated as high\n",
    "            resolution and the remaining ones as low resolution.\n",
    "        start_queue_at_epoch: First epoch in which the queued features are passed\n",
    "            to the loss. In earlier epochs the queues are only filled.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_queues=2, start_queue_at_epoch=2):\n",
    "        super().__init__()\n",
    "        # No weights are requested, so the backbone starts from random initialization.\n",
    "        resnet = torchvision.models.resnet18()\n",
    "        # Keep every child module except the last one (the classification layer).\n",
    "        self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n",
    "        self.projection_head = SwaVProjectionHead(512, 512, 128)\n",
    "        self.prototypes = SwaVPrototypes(128, 512, 1)\n",
    "        self.start_queue_at_epoch = start_queue_at_epoch\n",
    "        self.queues = nn.ModuleList(\n",
    "            [MemoryBankModule(size=(3840, 128)) for _ in range(n_queues)]\n",
    "        )\n",
    "        self.criterion = SwaVLoss()\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the SwaV loss for one batch.\n",
    "\n",
    "        `batch[0]` must hold the list of views with the high-resolution views first.\n",
    "        \"\"\"\n",
    "        views = batch[0]\n",
    "        # There is one queue per high-resolution view, so the number of queues\n",
    "        # determines where the list of views is split.\n",
    "        n_high_resolution = len(self.queues)\n",
    "        high_resolution = views[:n_high_resolution]\n",
    "        low_resolution = views[n_high_resolution:]\n",
    "        self.prototypes.normalize()\n",
    "\n",
    "        high_resolution_features = [self._subforward(x) for x in high_resolution]\n",
    "        low_resolution_features = [self._subforward(x) for x in low_resolution]\n",
    "\n",
    "        high_resolution_prototypes = [\n",
    "            self.prototypes(x, self.current_epoch) for x in high_resolution_features\n",
    "        ]\n",
    "        low_resolution_prototypes = [\n",
    "            self.prototypes(x, self.current_epoch) for x in low_resolution_features\n",
    "        ]\n",
    "        # None until `start_queue_at_epoch` is reached; the queues are filled regardless.\n",
    "        queue_prototypes = self._get_queue_prototypes(high_resolution_features)\n",
    "        loss = self.criterion(\n",
    "            high_resolution_prototypes, low_resolution_prototypes, queue_prototypes\n",
    "        )\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Return an Adam optimizer over all parameters with a learning rate of 0.001.\"\"\"\n",
    "        optim = torch.optim.Adam(self.parameters(), lr=0.001)\n",
    "        return optim\n",
    "\n",
    "    def _subforward(self, input):\n",
    "        \"\"\"Embed one view: backbone, flatten, projection head, then L2 normalization.\"\"\"\n",
    "        features = self.backbone(input).flatten(start_dim=1)\n",
    "        features = self.projection_head(features)\n",
    "        features = nn.functional.normalize(features, dim=1, p=2)\n",
    "        return features\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _get_queue_prototypes(self, high_resolution_features):\n",
    "        \"\"\"Update the queues and return the prototype outputs of the queued features.\n",
    "\n",
    "        Returns:\n",
    "            A list with one entry per queue, or None while `current_epoch` is below\n",
    "            `start_queue_at_epoch` (the queues are still updated in that case).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the number of high-resolution feature tensors differs\n",
    "                from the number of queues.\n",
    "        \"\"\"\n",
    "        if len(high_resolution_features) != len(self.queues):\n",
    "            raise ValueError(\n",
    "                f\"The number of queues ({len(self.queues)}) should be equal to the number of high \"\n",
    "                f\"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly.\"\n",
    "            )\n",
    "\n",
    "        # Get the queue features\n",
    "        queue_features = []\n",
    "        for queue, features in zip(self.queues, high_resolution_features):\n",
    "            _, queued = queue(features, update=True)\n",
    "            # Queue features are in (num_ftrs X queue_length) shape, while the high res\n",
    "            # features are in (batch_size X num_ftrs). Swap the axes for interoperability.\n",
    "            queue_features.append(torch.permute(queued, (1, 0)))\n",
    "\n",
    "        # If loss calculation with queue prototypes starts at a later epoch,\n",
    "        # just queue the features and return None instead of queue prototypes.\n",
    "        # (current_epoch is never negative, so this also covers start_queue_at_epoch <= 0.)\n",
    "        if self.current_epoch < self.start_queue_at_epoch:\n",
    "            return None\n",
    "\n",
    "        # Assign prototypes\n",
    "        return [self.prototypes(x, self.current_epoch) for x in queue_features]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the SwaV LightningModule. Its ResNet-18 backbone is created without\n",
    "# pretrained weights, so training starts from random initialization.\n",
    "model = SwaV()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SwaV augmentation with the library defaults. SwaV.training_step treats the first\n",
    "# two views of each batch as high resolution and the rest as low resolution, so\n",
    "# the crop configuration chosen here has to match that split.\n",
    "transform = SwaVTransform()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# we ignore object detection annotations by setting target_transform to return 0\n",
    "def target_transform(t):\n",
    "    \"\"\"Replace the annotation `t` of a sample with the constant placeholder 0.\n",
    "\n",
    "    The training step only reads the views in `batch[0]`, so the target is unused.\n",
    "    \"\"\"\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pascal VOC images serve as unlabeled training data: every image is passed\n",
    "# through `transform` and its annotation is replaced via `target_transform`.\n",
    "dataset = torchvision.datasets.VOCDetection(\n",
    "    \"datasets/pascal_voc\",\n",
    "    download=True,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos\n",
    "# (requires `from lightly.data import LightlyDataset`; the transform must be\n",
    "# passed as well, otherwise the samples are not turned into views):\n",
    "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on a GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    accelerator = \"gpu\"\n",
    "else:\n",
    "    accelerator = \"cpu\"\n",
    "\n",
    "# Shuffled batches of 128 samples; the incomplete last batch is dropped.\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset, num_workers=8, batch_size=128, drop_last=True, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for 10 epochs on a single device of the selected accelerator type.\n",
    "trainer = pl.Trainer(\n",
    "    accelerator=accelerator,\n",
    "    devices=1,\n",
    "    max_epochs=10,\n",
    ")\n",
    "trainer.fit(train_dataloaders=dataloader, model=model)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
